%% solve_HJB.m
% Lee and Maxted; updated for the paper (reference here).
% Code based on Ben Moll's code repository.
% -------------------------------------------------------------------------


% -------------------------------------------------------------------------
% Solve the HJB equation of the time-consistent agent by implicit upwind
% finite differences on the Nb x Nz grid (bb, zz).
%
% Reads from the workspace: rho, gamma, r, wedgeCC, zz, bb, bb_negative,
% z, bmin, bmax, db, Nb, Nz, Aswitch, Delta, maxit, crit.
% Leaves in the workspace (among others):
%   v, V  - value function after the last update (Nb x Nz)
%   c     - upwinded consumption policy from the last iteration
%   A     - transition matrix incl. Aswitch, from the last iteration
%   dist  - sup-norm change in V at each iteration
%   V_n   - value function at the start of each iteration (Nb x Nz x n)

% Resources at zero drift: income plus interest, with the wedge applied to
% negative balances. It does not change across iterations, so compute it
% once; it is also the consumption used where neither upwind direction
% applies.
c0 = zz + r.*bb + wedgeCC.*bb_negative;

%Initial guess: value of consuming c0 forever
v0 = (1/rho)*(c0.^(1-gamma)-1)/(1-gamma);
v = v0;

% Preallocate / reset arrays that are filled by index. This script shares
% the caller's workspace, so leftovers from an earlier run (possibly with
% a different Nb or Nz) would otherwise leak into this solve.
dVf = zeros(Nb, Nz);
dVb = zeros(Nb, Nz);
V_n = [];
dist = [];

for n=1:maxit
    V = v;
    V_n(:,:,n)=V;

    % forward difference; the boundary value at bmax is chosen so that
    % forward-difference saving is zero there (state constraint b <= bmax).
    % NOTE(review): assumes z is the income grid matching the columns of zz.
    dVf(1:Nb-1,:) = (V(2:Nb,:)-V(1:Nb-1,:))/db;
    dVf(Nb,:) = (z + r.*bmax + wedgeCC.*bmax.*(bmax<0)).^(-gamma);
    % backward difference; boundary value at bmin plays the same role for
    % the state constraint b >= bmin
    dVb(2:Nb,:) = (V(2:Nb,:)-V(1:Nb-1,:))/db;
    dVb(1,:) = (z + r.*bmin + wedgeCC.*bmin.*(bmin<0)).^(-gamma);

    %consumption and savings with forward difference (derivative floored
    %to keep the inverse marginal utility finite)
    cf = max(dVf, 10^(-10)).^(-1/gamma);
    ssf = c0 - cf;
    %consumption and savings with backward difference
    cb = max(dVb, 10^(-10)).^(-1/gamma);
    ssb = c0 - cb;

    %Use Hamiltonians as tiebreaker when both directions are admissible
    Hb = (cb.^(1-gamma))/(1-gamma) + dVb.*ssb;
    Hf = (cf.^(1-gamma))/(1-gamma) + dVf.*ssf;

    Ineither = (1-(ssf>0)) .* (1-(ssb<0));
    Iunique = (ssb<0).*(1-(ssf>0)) + (1-(ssb<0)).*(ssf>0);
    Iboth = (ssb<0).*(ssf>0);
    Ib = Iunique.*(ssb<0) + Iboth.*(Hb>Hf);
    If = Iunique.*(ssf>0) + Iboth.*(Hf>=Hb);
    I0 = Ineither;

    c = cf.*If + cb.*Ib + c0.*I0;
    u = (c.^(1-gamma)-1)/(1-gamma);

    %CONSTRUCT MATRIX: one tridiagonal Nb x Nb block per income state,
    %placed on the block diagonal, then add income switching.
    X = -Ib.*ssb/db;
    Y = -If.*ssf/db + Ib.*ssb/db;
    Z = If.*ssf/db;

    Ablocks = cell(1, Nz);
    for nz=1:Nz
        Ablocks{nz} = spdiags(Y(:,nz),0,Nb,Nb) ...
            + spdiags(X(2:Nb,nz),-1,Nb,Nb) ...
            + spdiags([0;Z(1:Nb-1,nz)],1,Nb,Nb);
    end
    A = blkdiag(Ablocks{:}) + Aswitch;

    % rows of a generator must sum to zero
    if max(abs(sum(A,2)))>10^(-12)
        disp('Improper Transition Matrix')
        break
    end

    % implicit step: (rho + 1/Delta) V_new - A V_new = u + V_old/Delta
    B = (rho + 1/Delta)*speye(Nz*Nb) - A;

    u_stacked = reshape(u,Nb*Nz,1);
    V_stacked = reshape(V,Nb*Nz,1);

    updateRHS = u_stacked + V_stacked/Delta;
    V_stacked = B\updateRHS;

    V = reshape(V_stacked, Nb, Nz);

    Vchange = V - v;
    v = V;

    dist(n) = max(max(abs(Vchange)));
    if dist(n)<crit
        disp('Value Function Converged, Iteration = ')
        disp(n)
        break
    end
end

% Previously a run that exhausted maxit (or hit an improper matrix) ended
% silently; make that visible. ~(x < crit) also catches NaN.
if isempty(dist) || ~(dist(end) < crit)
    warning('solve_HJB:noConvergence', ...
        'HJB value function iteration ended without reaching crit.');
end


%% Adjust for Present Bias
% Keep the time-consistent generator and policy (A1, c1, bdot1). Then
% scale consumption by (betaE/beta)^(1/gamma)/psiE to get the
% present-biased policy c, and rebuild the generator A from its drift.
A1 = A;
c1 = c;
bdot1 = zz + r.*bb + wedgeCC.*bb_negative - c1;
c = ((betaE/beta)^(1/gamma))*(1/psiE)*c1;
bdot = zz + r.*bb + wedgeCC.*bb_negative - c;

% State constraints: the drift must not leave the grid at either end.
% At bmin the drift is floored at a small positive value.
bdot1(1,:) = max(bdot1(1,:), 1e-10);
bdot(1,:) = max(bdot(1,:), 1e-10);
% At bmax an upward drift has no grid point to flow to (Z(Nb,:) is not
% placed in the matrix below). A positive bdot(Nb,:) would therefore make
% the last row of A sum to -bdot/db instead of zero. Because
% bdot = bdot1 + (1 - scale)*c1, this happens whenever the scaling factor
% is < 1 at a point where the agent was not dissaving. Cap the drift at 0.
bdot1(Nb,:) = min(bdot1(Nb,:), 0);
bdot(Nb,:) = min(bdot(Nb,:), 0);

% upwind: negative drift uses the backward neighbour, positive the forward
X = - min(bdot,0)/db;
Y = - max(bdot,0)/db + min(bdot,0)/db;
Z = max(bdot,0)/db;

% one tridiagonal Nb x Nb block per income state on the block diagonal
Ablocks = cell(1, Nz);
for nz=1:Nz
    Ablocks{nz} = spdiags(Y(:,nz),0,Nb,Nb) ...
        + spdiags(X(2:Nb,nz),-1,Nb,Nb) ...
        + spdiags([0;Z(1:Nb-1,nz)],1,Nb,Nb);
end
A = blkdiag(Ablocks{:}) + Aswitch;

% rows of a generator must sum to zero (same check as in the HJB loop)
if max(abs(sum(A,2)))>10^(-12)
    warning('solve_HJB:improperGenerator', ...
        'Improper Transition Matrix (present-biased A)');
end


%% Calculate V for present-biased agent
% Value of following the fixed present-biased policy c under the fixed
% generator A. It uses the same implicit update as the HJB loop,
% (rho + 1/Delta) V_new - A V_new = u + V_old/Delta, starting from the
% time-consistent value function, which is kept in vOld.
vOld = v;

% c and A do not change inside this loop, so the flow utility and the
% system matrix are built once instead of on every iteration.
u = (c.^(1-gamma)-1)/(1-gamma);
B = (rho + 1/Delta)*speye(Nz*Nb) - A;
u_stacked = reshape(u,Nb*Nz,1);

% Start a fresh convergence history. Otherwise entries from the previous
% loop would survive past this loop's last iteration.
dist = [];
for n=1:maxit
    V = v;

    V_stacked = reshape(V,Nb*Nz,1);
    updateRHS = u_stacked + V_stacked/Delta;
    V_stacked = B\updateRHS;
    V = reshape(V_stacked, Nb, Nz);

    Vchange = V - v;
    v = V;

    dist(n) = max(max(abs(Vchange)));
    % convergence is never accepted on the first update (n > 1)
    if dist(n)<crit && n > 1
        disp('Value Function Converged, Iteration = ')
        disp(n)
        break
    end
end

% make a run that exhausts maxit visible; ~(x < crit) also catches NaN
if isempty(dist) || ~(dist(end) < crit)
    warning('solve_HJB:noConvergencePB', ...
        'Present-biased value iteration ended without reaching crit.');
end

% ------------------------------------------------------------------------- 

